import torch
import torch.nn as nn


#  RNN层送入批量数据
def test():
    # 词向量维度 128, 隐藏向量维度 256
    rnn = nn.RNN(input_size=128, hidden_size=256)
    # 第一个数字: 表示句子长度,也就是词语个数
    # 第二个数字: 批量个数，也就是句子的个数
    # 第三个数字: 词向量维度
    inputs = torch.randn(5, 32, 128)
    hn = torch.zeros(1, 32, 256)
    # 获取输出结果
    output, hn = rnn(inputs, hn)
    print("输出向量的维度：\n", output.shape)
    print("隐含层输出的维度：\n", hn.shape)


if __name__ == '__main__':
    test()
